
Add step method state and make step results deterministic with respect to it #7508

Merged: 7 commits merged into pymc-devs:main from the checkpoints branch on Oct 7, 2024

Conversation

@lucianopaz (Contributor) commented Sep 18, 2024

Description

The original intent of this PR was to fully address #7503. That proved to be a very long task, so the current PR focuses only on closing #5797 and adding sampling_state to all pymc step methods. I'll leave the original description further down.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7508.org.readthedocs.build/en/7508/

Original intent (OUTDATED)

This will be a long PR, but I want to open the discussion of the design in its early phases. The overall goal is to provide the ability to pause and later resume sampling that is based on pymc step methods. Once this PR is finished, I hope that we'll get into the problem of adding this ability when sampling with blackjax, nutpie and numpyro.

There will be 4 subgoals in this PR. I'll write them down and list the tasks that I'll do in each:

  1. Write something that can dump or load the step method's sampling state.
  • Create a class that represents the sampling state
  • Get to dump or load the state for metropolis step methods
  • Get to dump or load the state for compound step methods
  • Get to dump or load the state for slice step methods
  • Get to dump or load the state for HMC step methods
  • Get to dump or load the state for NUTS step methods
  • Ensure the sampling state completely determines the step method's results
  2. Write something that can dump or load the trace in its intermediate stages.
  • Get to dump or load the trace for the NDArray backend
  • Get to dump or load the trace for McBackend (this might be done already?)
  • Get to dump or load the trace for the ArviZ backend
  3. Add a way to restart sampling using previous sampling states and traces
  • Single chain MCMC
  • Parallel MCMC
  • Population sampling
  4. Add async sampling methods (like nutpie's non-blocking sampling)
  • Single chain MCMC
  • Parallel MCMC
  • Population sampling

@lucianopaz added labels enhancements, major (Include in major changes release notes section), samplers on Sep 18, 2024
@lucianopaz force-pushed the checkpoints branch 5 times, most recently from 8782ed3 to 8c44fd0, on September 19, 2024 12:15
codecov bot commented Sep 19, 2024

Codecov Report

Attention: Patch coverage is 97.43590% with 11 lines in your changes missing coverage. Please review.

Project coverage is 92.69%. Comparing base (af5ea5c) to head (af74f2c).
Report is 14 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| pymc/step_methods/hmc/quadpotential.py | 96.15% | 4 Missing ⚠️ |
| pymc/sampling/mcmc.py | 83.33% | 3 Missing ⚠️ |
| pymc/sampling/parallel.py | 83.33% | 1 Missing ⚠️ |
| pymc/step_methods/metropolis.py | 99.02% | 1 Missing ⚠️ |
| pymc/step_methods/state.py | 98.11% | 1 Missing ⚠️ |
| pymc/util.py | 93.75% | 1 Missing ⚠️ |
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #7508      +/-   ##
==========================================
+ Coverage   92.43%   92.69%   +0.25%     
==========================================
  Files         103      104       +1     
  Lines       17109    17402     +293     
==========================================
+ Hits        15814    16130     +316     
+ Misses       1295     1272      -23     
| Files with missing lines | Coverage Δ |
| --- | --- |
| pymc/math.py | 77.71% <100.00%> (ø) |
| pymc/sampling/population.py | 74.67% <100.00%> (+1.12%) ⬆️ |
| pymc/step_methods/arraystep.py | 94.66% <100.00%> (+0.07%) ⬆️ |
| pymc/step_methods/compound.py | 97.56% <100.00%> (+0.38%) ⬆️ |
| pymc/step_methods/hmc/base_hmc.py | 91.91% <100.00%> (+1.15%) ⬆️ |
| pymc/step_methods/hmc/hmc.py | 93.54% <100.00%> (+0.69%) ⬆️ |
| pymc/step_methods/hmc/nuts.py | 97.40% <100.00%> (+0.12%) ⬆️ |
| pymc/step_methods/slicer.py | 96.70% <100.00%> (+0.31%) ⬆️ |
| pymc/step_methods/step_sizes.py | 80.95% <100.00%> (+5.95%) ⬆️ |
| pymc/sampling/parallel.py | 88.50% <83.33%> (ø) |
| ... and 5 more | |

... and 12 files with indirect coverage changes

@lucianopaz (Contributor Author)
I think that this PR now has good coverage for the first subgoal that I talked about in the description. All of pymc's step methods now expose a sampling_state property that fully determines the following draws that the sampler will produce. To do this, and also to test it, I had to make one major change: all of the step methods are now detached from numpy's global random state. This doesn't have significant consequences for the average user, but it opens up many possibilities:

  • samplers could run concurrently if they use different step methods
  • to ensure reproducibility, one doesn't need to affect the global state, which might have consequences on other packages that are out of pymc's scope
  • the init_nuts step method needs to be adapted down the line to use random generators instead of a list of random seeds.

I'll ping some people to get some feedback on the current situation, because I've invested quite a lot of time into getting it right (and also have a clean git history), and I would like to get reviews before moving on. Pinging @ricardoV94, @michaelosthege, @Armavica, @ColCarroll and @junpenglao.

@ricardoV94 (Member) left a comment
Hi Luciano, had a quick glance. Looks alright but I wonder if we can make it a tad simpler by using dataclasses instead of reinventing something that seems similar?

I left other questions as comments.

Thanks for the initiative, making step methods less opaque will be great for customizing/controlling sampling

pymc/step_methods/arraystep.py (resolved)
pymc/step_methods/hmc/base_hmc.py (resolved)
_num_divs_sample: int


class BaseHMC(GradientSharedStep, WithSamplingState):
Member
Why do we need a separate class? Why not be part of the baseclass?

Contributor Author
Which baseclass? WithSamplingState? Or BaseHMCState? If you mean the latter, it's because the BaseHMC step method has properties that are different from other step methods, so I need to represent its state differently. If you mean the WithSamplingState, that's because the WithSamplingState provides the sampling_state property accessors to the step method. Let me know if you meant something else.

Member
I mean all step methods should have whatever WithSamplingState implements in the base class

Contributor Author
Oh, I think you're right. It must be a leftover from a previous state of my commits.

Member
I would consider a mixin to be a cleaner design, because it doesn't make third party step methods inherit WithSamplingState even if they aren't compatible. It should also make things easier to test by not introducing cross-dependencies.

Member
I don't see any compatibility issues if a step sampler doesn't make use of the functionality in the base class.

Contributor Author
Oh, I see what @michaelosthege is saying. The current design is more in line with what @ricardoV94 said: I already have BlockedStep as a subclass of WithSamplingState, so including it in BaseHMC's ancestors explicitly is pointless. The base StepMethodState only tries to access the rng property. I could provide a default value factory for that property, and that way there won't be any problem for third party libraries that define their own step methods. They would simply get a useless random number Generator object. They could use it if they wanted to, though.

import numpy as np


class MetaDataClassState(type):
Member
What is this doing?

Contributor Author
I already said it somewhere else but I'll add it here too. I wanted to make the step method states simple dataclass wrapped classes. I ran into problems revolving around __eq__ and then also around positional arguments being defined after arguments with default values. To avoid these problems, I had to add some arguments to the dataclass decorator and I would sometimes forget them, leading to errors down the line. The simplest solution that didn't involve having to always write down boilerplate code with every State definition was to add this metaclass. It simply creates the subclass type and then wraps that using the dataclass(eq=False, kw_only=True) decorator. That way, you just need to inherit from DataClassState and you'll get the guarantee of working with a dataclass that has the specially crafted __eq__ method.
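
A minimal sketch of the pattern being described, using only the names visible in the diff (MetaDataClassState, DataClassState); the example subclass and its fields are invented for illustration, and kw_only requires Python 3.10+:

```python
from dataclasses import dataclass, fields

import numpy as np


class MetaDataClassState(type):
    """Automatically wrap every class created with this metaclass in dataclass(eq=False, kw_only=True)."""

    def __new__(cls, name, bases, namespace):
        new_cls = super().__new__(cls, name, bases, namespace)
        return dataclass(eq=False, kw_only=True)(new_cls)


class DataClassState(metaclass=MetaDataClassState):
    """Base class: subclasses become keyword-only dataclasses without repeating any decorator."""


class ExampleState(DataClassState):  # hypothetical subclass, not taken from the PR
    rng: np.random.Generator
    steps_until_tune: int = 0


state = ExampleState(rng=np.random.default_rng(42))
print([f.name for f in fields(state)])  # ['rng', 'steps_until_tune']
```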

return dataclass(eq=False, kw_only=True)(super().__new__(cls, name, bases, namespace))


class DataClassState(metaclass=MetaDataClassState):
Member
Why is this needed? Can we do simpler than this?

Can we just use a vanilla dataclass?

Contributor Author
I already answered this above. This solution avoids having to add boilerplate code all around the State definitions. The main problem is that if you have a class that inherits from another class that uses the @dataclass decorator, the subclass won't be a "dataclass". With the metaclass approach, the subclasses will also be "dataclass" types and they will use the nice __eq__ method that I had to write here.

pymc/step_methods/state.py (outdated, resolved)
kwargs[field.name] = deepcopy(_val)
return state_class(**kwargs)

@sampling_state.setter
Member
When is the setter used?

Contributor Author
At the moment, it's only used in the tests. Calling step_method.sampling_state = state will set the step method's attributes to the values represented by the provided state. The goal is that we'll use it when we want to jump-start the sampler from some past sampling state. The workflow I have in mind is this (sketched in code after the list):

  1. Have the model already built
  2. Have the samplers already built
  3. If someone provides a past sampling state from which to start, set the sampler's state to that
  4. Start sampling using the sampler's set state and collect the results
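
A rough sketch of that workflow; the toy model is illustrative, and the rng keyword mirrors its usage elsewhere in this diff:

```python
import pymc as pm

with pm.Model():
    pm.Normal("x")
    # the step method gets its own random Generator, seeded explicitly
    step = pm.Metropolis(rng=123456)

# snapshot the step method's current state (a DataClassState container)
saved = step.sampling_state

# ... sampling mutates the step method's internal state ...

# restore: from here on the step method reproduces the draws
# it would have produced right after the snapshot was taken
step.sampling_state = saved
```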

Member
I prefer an explicit method for that functionality? Something like step_method.set_state(state)?

It's 100% subjective opinion though

Contributor Author
I agree to disagree on this one. The set_state method looks like any other method call and it’s not explicitly saying that it won’t have any return value. The attribute assignment syntax on the other hand is much more explicit on its intent.

@ricardoV94 (Member) Sep 26, 2024
So why is there a set_rng? The same argument you used against set_state would apply. Also, in my experience properties have always been a PITA, because at some point we figure out we would like kwargs/customization and there's no way to refactor a property into a method in a backwards-compatible way.

Contributor Author
I added a set_rng method because the subclasses are supposed to overload it. The HMC step methods have their own rng, and their potential also has a spawned rng. Setting the rng had to work differently than for the rest of the step methods, so I decided to make it a method instead of a property. Anyway, I still prefer the sampling_state as a property mixin. If we eventually realize that we don't want it, we can always add a new set_state method and issue a DeprecationWarning or RuntimeError in the property setter.

@ricardoV94 (Member) Sep 27, 2024
Let me try one last argument. Setting a property (which looks like an attribute) does not make me intuitively think that it will actually affect the sampler. Especially in this case, where the property (sampler.sampling_state) is actually a read-only copy of the internal state, not the state itself.

It's like if you have a PyMC model, which has the attribute model.datalogp. I wouldn't expect model.datalogp = 0 to be a valid way of overriding the model logp. Of course if I read the source code I'll be able to figure it out, but from just reading user code I wouldn't think it would do what I want it to do.

I find sampler.set_rng() much more obvious in signalling that it will actually affect the rng used, and sampler.set_state() that it will actually affect the state used.

A final argument I found, but don't necessarily care about, has to do with inheritance: calling super() on a property is clumsy if you want to combine the effects of the base class method and some tweaking in the inherited class.


def get_random_generator(seed: RandomGenerator = None, copy=True) -> np.random.Generator:
if copy:
seed = deepcopy(seed)
Member
When/why is copying the seed needed?

Contributor Author
When seed is a numpy.random.Generator. If you don't do that, numpy.random.default_rng(seed) will return seed itself. If you provide a BitGenerator, the Generator object will be new, but its BitGenerator will be the same object that you passed in, making it potentially shared with another Generator. To ensure that neither of those two scenarios happens, I deepcopy the seed by default.
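
The two scenarios shown in plain numpy (nothing here is pymc-specific):

```python
from copy import deepcopy

import numpy as np

rng = np.random.default_rng(123)
assert np.random.default_rng(rng) is rng           # a Generator is returned as-is

bg = np.random.PCG64(123)
g1 = np.random.default_rng(bg)
g2 = np.random.default_rng(bg)
assert g1.bit_generator is g2.bit_generator        # the BitGenerator ends up shared

fresh = np.random.default_rng(deepcopy(rng))       # deepcopy breaks the sharing
assert fresh is not rng and fresh.bit_generator is not rng.bit_generator
```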

Member
Add an explicit instance check for Generator? And/or comment?

Contributor Author
I'll add a comment, but I found that numpy does all of the hard type-checking work and it felt like a waste to repeat it.

@ricardoV94 (Member)

to ensure reproducibility, one doesn't need to affect the global state, which might have consequences on other packages that are out of pymc's scope

This is a desired feature not a problem. There's an old issue that you can link to/ close if we get rid of the global RNG with this PR

samplers could run concurrently if they use different step methods

How? Don't they always require interleaving steps (i.e. conditioned on a valid state)?

@lucianopaz (Contributor Author)

Hi Luciano, had a quick glance. Looks alright but I wonder if we can make it a tad simpler by using dataclasses instead of reinventing something that seems similar?

I left other questions as comments.

Thanks for the initiative, making step methods less opaque will be great for customizing/controlling sampling

Thanks for the review @ricardoV94! I am using dataclass. The original approach was to add the @dataclass decorator to every state class. The problem there was that __eq__ didn't work well with numpy arrays or with random number generators.

That's why my second approach was to have a base class (DataClassState) that was a dataclass with an __eq__ method. The problem with that approach was that I also had to add the @dataclass(eq=False) decorator to every state class that inherited from that common base. While working on this, I frequently forgot to add either the decorator or the eq=False, leading to failures down the line.

That's why, in the final and current approach, I have a metaclass that will call dataclass(eq=False)(cls) on any class that inherits from the DataClassState base class. That way, I can ensure that all child classes are handled as dataclasses, and also use the base class __eq__ that handles the weird array and random generator types.

One last thing was that I also had to add kw_only to the dataclass, to enable safe inheritance between the classes. If I didn't, some classes that had default values for some fields could not be used as ancestors, because the auto-generated __init__ would put those before other mandatory positional arguments.
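
For reference, this is the failure mode with a plain @dataclass holding array-valued fields (a generic illustration, not code from this PR):

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ArrayState:  # hypothetical example
    values: np.ndarray


a = ArrayState(np.arange(3))
b = ArrayState(np.arange(3))
try:
    # the generated __eq__ compares field tuples, and bool(array == array) is ambiguous
    a == b
except ValueError as err:
    print(err)  # "The truth value of an array with more than one element is ambiguous..."
```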

@lucianopaz (Contributor Author)

to ensure reproducibility, one doesn't need to affect the global state, which might have consequences on other packages that are out of pymc's scope

This is a desired feature not a problem. There's an old issue that you can link to/ close if we get rid of the global RNG with this PR

Ouch, it looks like I'm a bit of a broken record... I opened #5797 more than 2 years ago.

samplers could run concurrently if they use different step methods

How? Don't they always require interleaving steps (i.e. conditioned on a valid state)?

Before this PR, step methods relied on the global random state, so to run concurrently each sampler had to live in its own process with its own copy of that state. I'm not sure how fork or forkserver interact with numpy's global numpy.random.mtrand._rand state, but if they somehow shared it and used locks to avoid corrupting it with race conditions, the sampling results wouldn't be deterministic given the seed: one chain could draw a sample faster than another at some times and slower at others, so it would end up using different random states to generate its samples.

But even if fork and forkserver multiprocessing produce copies of the global state that are unique to each process, any concurrent thread in the process that touches the global random state would affect the draws from the step method. And likewise, the steps from pymc would affect the global state, indirectly affecting other things that might rely on it. Just to be clear, none of this means that pymc or other concurrent code would break the sampling; I just mean that the sampling results could be affected because the global random state could be changed in the middle of sampling by anyone.

What I did with this PR was to isolate the step methods from everything else, ensuring that they won't interact or interfere with things that we don't control and that our users might not even be aware of.
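
A numpy-level illustration of that isolation (not necessarily how pymc seeds its steppers internally; Generator.spawn needs numpy >= 1.25):

```python
import numpy as np

# one independent, reproducible stream per chain/step method,
# without ever touching np.random's global state
root = np.random.default_rng(2024)
chain_rngs = root.spawn(4)

np.random.seed(0)  # a third party resetting the global state...
draws = [rng.normal() for rng in chain_rngs]  # ...has no effect on these draws
```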

@ricardoV94 (Member)

The problem there was that __eq__ didn't work well with numpy arrays or with random number generators.

Why do we need __eq__?

@lucianopaz (Contributor Author)

Why do we need __eq__?

For convenience. If we need to assert equality, it’s much better to have this method

@ricardoV94 (Member)

Why do we need __eq__?

For convenience. If we need to assert equality, it’s much better to have this method

Why not wait until we see a need then?

@lucianopaz (Contributor Author)

Why do we need __eq__?

For convenience. If we need to assert equality, it’s much better to have this method

Why not wait until we see a need then?

I did need it in all of the tests I wrote

@ricardoV94 (Member) commented Sep 23, 2024

Why do we need __eq__?

For convenience. If we need to assert equality, it’s much better to have this method

Why not wait until we see a need then?

I did need it in all of the tests I wrote

That's more an argument for a test utility than code we need to strictly maintain. Checking numpy array and random generator equality shows up in other scenarios

@michaelosthege (Member) left a comment
Great work! The refactoring of sampler RNGs could be extracted into its own PR and merged first?

Item 4. of your description sounds unrelated.

If I understand correctly, your approach with the mixin class has the nice benefit of not adding overhead to every iteration!

Regarding item 2. (where to dump traces), I previously pointed at the stats, because they already contain some/many state fields. In McBackend the stats can be sparse, so one could emit a sampler_state every 100 iterations or so.

  • Last time I checked ArviZ/xarray didn't support saving sparse arrays to disk. But that's for sure workaroundable at the ArviZ level.
  • ClickHouseBackend can persist sparse stats already.

@@ -292,7 +295,7 @@ def test_step_discrete(self):
unc = np.diag(C) ** 0.5
check = (("x", np.mean, mu, unc / 10.0), ("x", np.std, unc, unc / 10.0))
with model:
step = Metropolis(S=C, proposal_dist=MultivariateNormalProposal)
step = Metropolis(S=C, proposal_dist=MultivariateNormalProposal, rng=123456)
Member
Why is this seed different?

Contributor Author
Seeding these tests was a PITA. I kept running into sporadic errors and flakiness while I was polishing the step methods' detachment from the global random state, and some rngs were left with these weird intermediate seeds.

@@ -36,6 +36,8 @@
from tests.helpers import RVsAssignmentStepsTester, StepMethodTester
from tests.models import mv_simple, mv_simple_discrete, simple_categorical

SEED = sum(ord(c) for c in "test_metropolis")
Member
but why 😂

Contributor Author
Same as above. I can only answer with an Argentinian meme:

[image]

Comment on lines 30 to 39
this_fields = set([f.name for f in fields(self)])
other_fields = set([f.name for f in fields(other)])
Member
inner list comprehension is unnecessary

Contributor Author
Why not? The fields function will return Field objects, which have a bunch of extra dataclass-specific attributes (e.g. type, default, default_factory). I just want to check that the names are the same and use those names later for getattr.

Contributor Author
Oh! I think I understand what you're saying. A set of Field objects should already be enough to test if this_fields == other_fields because all of the other attributes should also match.

Member
actually that was not my point (but you might be right about it)

set(generator) aka set(a for a in "ABCD") works. You can leave out creating the inner list (and then iterating it again when creating the set)
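
In other words (a throwaway example of the suggestion):

```python
from dataclasses import dataclass, fields


@dataclass
class Point:  # throwaway example
    x: int
    y: int


p = Point(1, 2)
this_fields = {f.name for f in fields(p)}  # set comprehension: no intermediate list is built
print(this_fields)                         # {'x', 'y'}
```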

return v1 == v2


class WithSamplingState:
Member
Can you add docstrings for the three new classes to explain how they fit together?

For WithSamplingState I understand that it's a mixin adding a sampling_state property which, upon access, returns a new container object of a DataClassState subtype. This container holds copies of field values of the WithSamplingState object. (?)

Contributor Author
That's exactly right. I'll add comments and docstrings.

@ricardoV94 (Member)

If I understand correctly, your approach with the mixin class has the nice benefit of not adding overhead to every iteration!

What overhead? There's nothing computed every iteration

@lucianopaz (Contributor Author)

Great work! The refactoring of sampler RNGs could be extracted into its own PR and merged first?

Thanks @michaelosthege! I think that it could be extracted. I would need to add docstring entries for rng first though.

Regarding item 2. (where to dump traces), I previously pointed at the stats, because they already contain some/many state fields. In McBackend the stats can be sparse, so one could emit a sampler_state every 100 iterations or so.

* Last time I checked ArviZ/xarray didn't support saving sparse arrays to disk. But that's for sure workaroundable at the ArviZ level.

* `ClickHouseBackend` can persist sparse stats already.

Thanks for the pointers. I'll try to use that once I arrive at point 2

@lucianopaz (Contributor Author)

If I understand correctly, your approach with the mixin class has the nice benefit of not adding overhead to every iteration!

What overhead? There's nothing computed every iteration

I think that Michael means that I'm not building a StepMethodState at each step. The state is a property that can be built on demand, but there is no extra compute involved in regular sampling. The step method just chugs along until we eventually call step.sampling_state at some point during sampling. On the other hand, stats objects are built at each step, so they produce some overhead when compared to a situation in which only the samples got collected.

@michaelosthege (Member)

On the other hand, stats objects are built at each step, so they produce some overhead when compared to a situation in which only the samples got collected.

The stats are just dicts and as far as I can tell you didn't change anything about how they are collected (alongside draws) in every iteration.

@lucianopaz force-pushed the checkpoints branch 2 times, most recently from aa4f007 to 92d0845, on September 25, 2024 12:19
@lucianopaz (Contributor Author)

@michaelosthege, @ricardoV94. I may have to pivot to doing other stuff for some time, and I wouldn't want this PR to become stale and difficult to rebase onto a future state of main. I think that what has been implemented so far is good enough to merge into main. The main points are:

  • Step methods now have a numpy.random.Generator object that they use to make steps
  • Step methods are now detached from the global numpy RandomState and their draws are completely determined by their own generator
  • pymc.sample still uses random_seed and the RandomState to choose initial points and jittering, but all of the step methods and potentials produced down the line use their own Generator object. So once sampling has an initial point per chain, the global random state becomes irrelevant.
  • Step methods have the sampling_state property that determines completely the future draws that the stepper will produce.

All of this work is a major change from the current state of affairs in pymc, and I consider that it closes #5797. It doesn't close #7503, but I'll get back to that issue later. To fully address that, we need to discuss a couple of important intermediate things:

  1. How to store the sampling state to disk?
  2. How to store sampled draws iteratively to disk?
  3. Are there any changes we want to make to the default backend to support 2, or should we move to something like McBackend and protocol buffers?

I'll update this PR's description and mark it as ready for review and if you guys agree, we'll continue with the full checkpoint support in a near future.

@lucianopaz lucianopaz marked this pull request as ready for review September 26, 2024 10:02
@ricardoV94 (Member)

Sounds good to me, can you add a more informative PR title?

@ricardoV94 (Member)

BTW I still favor going with dataclasses instead of the custom new classes and have the equality as a detached test utility. There is no functionality that depends on equality right now or in the foreseeable future, unless I missed something.

Complexity for testing purposes seems backwards to me.

return False
if isinstance(v1, (list, tuple)): # noqa: UP038
return len(v1) == len(v2) and all(
DataClassState.compare_values(v1i, v2i) for v1i, v2i in zip(v1, v2)
Member
You could use zip(..., strict=True) and avoid explicitly comparing the length

Contributor Author
This actually wasn’t good because strict=True raises a ValueError, and I just want it to return False. I’ll keep the length comparison as it was
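
The behavioural difference in plain Python (no pymc involved):

```python
a, b = (1, 2, 3), (1, 2)

# with strict=True a length mismatch raises instead of evaluating to False
try:
    all(x == y for x, y in zip(a, b, strict=True))
except ValueError as err:
    print("raised:", err)

# the explicit length check keeps the comparison a plain boolean
print(len(a) == len(b) and all(x == y for x, y in zip(a, b)))  # False
```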

@lucianopaz changed the title from Checkpoints to "Add step method state and make step results deterministic with respect to it" on Sep 26, 2024
@lucianopaz (Contributor Author)

BTW I still favor going with dataclasses instead of the custom new classes and have the equality as a detached test utility. There is no functionality that depends on equality right now or in the foreseeable future, unless I missed something.

Complexity for testing purposes seems backwards to me.

I think that I can have a workaround for part of this, but it really depends on what you mean by dataclasses. Is the problem that I'm using a metaclass? Is it that I'm defining an __eq__? Is the problem that I'm using inheritance between DataClassState and some of its subclasses?

If your problem is that I'm relying on metaclasses, I think that I can do something differently to avoid them. If the problem is the __eq__, I think that I can move that to the tests as a function. What I don't want to change is the inheritance from DataClassState. That is important for static type analysis and also helped with reducing the amount of duplicate code.

@ricardoV94 (Member)

Yes my problem is the implementation of __eq__ and metaclass. I don't see why we need this for what is basically a namedtuple. We hold RNG / np.ndarray in many kinds of objects and we don't usually go about implementing equality if we don't have to (we do for TensorConstants in PyTensor for example)

@lucianopaz (Contributor Author)

Yes my problem is the implementation of __eq__ and metaclass. I don't see why we need this for what is basically a namedtuple. We hold RNG / np.ndarray in many kinds of objects and we don't usually go about implementing equality if we don't have to (we do for TensorConstants in PyTensor for example)

@ricardoV94, I changed the code to avoid the metaclass and detached the __eq__ from the code. I did have to leave a comparison utility function in the main codebase, because I want to be able to compare frozen fields. Let me know if you think that's good enough. If it is, I'll clean up the last commit and we can merge. I chose to leave it dirty for now to make it easier to undo the change if we don't like it.

@ricardoV94 (Member) commented Sep 27, 2024

@lucianopaz sounds good to me. Besides my personal preference for methods vs setter, I have one last suggestion and one question.

Rename compare_dataclass_values and compare_states to equal_dataclass_values and equal_states (or whatever they were called; the point being compare -> equal).

What's the deal with frozen fields? Why do we need them / to worry about them?

To be clear, I'm happy with the state and I am not blocking the merge after the rebase.

@lucianopaz (Contributor Author) commented Sep 27, 2024

Rename compare_dataclass_values and compare_states to equal_dataclass_values, equal_states (or whatever it was called, point being compare -> equal.

Good point, I'll do that.

What's the deal with frozen fields? Why do we need them / to worry about them?

The step methods have a bunch of extra information in them that gets set when they are created. My very first idea was to try to have some kind of pickle.dump approach where the entire step method would be stored. The problem with that was that the step methods have references to model variables and compiled functions. Serializing the step method would then force us to also serialize the whole model instance along with some compiled functions. This extra burden made me think that it would be better to only store and set some small bits of information from the steppers. Part of this information changes as samples are drawn, and other parts don't.

The long term goal was to be able to set the step method to a state where it could continue sampling as it had been doing before. Since I wouldn't be able to rebuild the full step method from what I save to disk, I needed to add some way to determine that the stored state was compatible with the step method that was being modified. That's why I decided to include some step information that doesn't change during sampling as frozen fields. If the saved state doesn't match with the stepper's frozen fields, that means that the state is not valid for the step method and an error should be raised.
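
A hypothetical illustration of such a frozen-field compatibility check (the class, field names, and the metadata mechanism are invented here, not taken from the PR):

```python
from dataclasses import dataclass, field, fields


@dataclass(kw_only=True)
class ToyState:  # hypothetical; field names and the "frozen" metadata key are made up
    scaling: float                                  # evolves while sampling/tuning
    n_dims: int = field(metadata={"frozen": True})  # fixed when the stepper is built


def check_compatible(current: ToyState, saved: ToyState) -> None:
    """Refuse to restore a saved state whose frozen fields don't match the live stepper."""
    for f in fields(current):
        if f.metadata.get("frozen") and getattr(current, f.name) != getattr(saved, f.name):
            raise ValueError(f"saved state incompatible: frozen field {f.name!r} differs")


# ok: only the mutable part differs, so the saved state may be restored
check_compatible(ToyState(scaling=1.0, n_dims=3), ToyState(scaling=0.5, n_dims=3))
```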

@lucianopaz lucianopaz merged commit 465d8ac into pymc-devs:main Oct 7, 2024
20 checks passed
@lucianopaz lucianopaz deleted the checkpoints branch October 7, 2024 08:00
ricardoV94 added a commit to ricardoV94/pymc that referenced this pull request Oct 8, 2024
PRs pymc-devs#7508 and pymc-devs#7492 introduced incompatible changes but were not tested simultaneously.

Deepcopying the steps in the tests leads to deepcopying the model which uses `clone_model`, which in turn does not support initvals.
ricardoV94 added a commit that referenced this pull request Oct 8, 2024
PRs #7508 and #7492 introduced incompatible changes but were not tested simultaneously.

Deepcopying the steps in the tests leads to deepcopying the model which uses `clone_model`, which in turn does not support initvals.
@lucianopaz mentioned this pull request Oct 16, 2024
mkusnetsov pushed a commit to mkusnetsov/pymc that referenced this pull request Oct 26, 2024
PRs pymc-devs#7508 and pymc-devs#7492 introduced incompatible changes but were not tested simultaneously.

Deepcopying the steps in the tests leads to deepcopying the model which uses `clone_model`, which in turn does not support initvals.
Successfully merging this pull request may close these issues:

Refactor step methods to use their own random stream